# Reusable workflow: builds a single wheel for one OS / Python / PyTorch /
# CUDA combination and uploads it as a build artifact.
name: wheels_build

on:
  workflow_call:
    inputs:
      # Runner label used for `runs-on`; labels containing 'ubuntu' select
      # the Linux code paths in the job below.
      os:
        type: string
        required: true
      # Python version; on Linux it is appended to "python" to pick the
      # interpreter binary.
      python:
        type: string
        required: true
      torch_version:
        description: 'Example: 1.13.1'
        type: string
        required: true
      cuda_short_version:
        description: 'Example: 117 for 11.7'
        type: string
        required: true
      # Suffix appended to the uploaded artifact's name.
      artifact_tag:
        type: string
        default: facebookresearch

# This YAML file could be cleaned up using YAML anchors, but they are not supported in GitHub Actions yet:
# https://github.com/actions/runner/issues/1182

# Environment shared by every step of the workflow.
env:
  # Architectures to compile for. Some of the code compiled here needs at
  # least "5.0" (presumably CUDA compute capability 5.0 - TODO confirm).
  TORCH_CUDA_ARCH_LIST: '5.0+PTX 6.0 6.1 7.0 7.5 8.0+PTX'
  MAX_JOBS: '4'
  # Otherwise distutils will complain on Windows about multiple versions
  # of MSVC.
  DISTUTILS_USE_SDK: '1'
  XFORMERS_BUILD_TYPE: Release
  TWINE_USERNAME: '__token__'
  XFORMERS_PACKAGE_FROM: 'wheel-${{ github.ref_name }}'

jobs:
  build:
    # Display name: <ubuntu|win>-py<python>-pt<torch>+cu<cuda>.
    name: "${{ contains(inputs.os, 'ubuntu') && 'ubuntu' || 'win' }}-py${{ inputs.python }}-pt${{ inputs.torch_version }}+cu${{ inputs.cuda_short_version }}"
    runs-on: "${{ inputs.os }}"
    # Ubuntu runners build inside the manylinux2014 image; for any other
    # runner the expression yields null, i.e. no container.
    container: "${{ contains(inputs.os, 'ubuntu') && 'quay.io/pypa/manylinux2014_x86_64' || null }}"
    timeout-minutes: 360
    defaults:
      run:
        shell: bash
    env:
      # Alias for the current Python interpreter. Windows has no
      # per-version binary, so it is just 'python3' there.
      PY: "python${{ contains(inputs.os, 'ubuntu') && inputs.python || '3' }}"

    steps:
      # Translate the short CUDA version input (e.g. "118") into step outputs:
      #   CUDA_VERSION         full toolkit version
      #   CUDA_VERSION_SUFFIX  "+cu<short>" unless this is the default CUDA build
      #   TORCH_ORG_S3_PATH    S3 path matching the CUDA flavour
      #   PUBLISH_PYPI         "1" only for the default CUDA build
      #   CUDA_INSTALL_SCRIPT  URL of the Linux toolkit installer
      - id: cuda_info
        shell: python
        run: |
          import os
          import sys
          print(sys.version)
          cushort = "${{ inputs.cuda_short_version }}"
          TORCH_CUDA_DEFAULT = "121"  # pytorch 2.1.0
          # https://github.com/Jimver/cuda-toolkit/blob/master/src/links/linux-links.ts
          CUDA_RELEASES = {
            "121": ("12.1.0", "https://developer.download.nvidia.com/compute/cuda/12.1.0/local_installers/cuda_12.1.0_530.30.02_linux.run"),
            "118": ("11.8.0", "https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run"),
            "117": ("11.7.1", "https://developer.download.nvidia.com/compute/cuda/11.7.1/local_installers/cuda_11.7.1_515.65.01_linux.run"),
            "116": ("11.6.2", "https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda_11.6.2_510.47.03_linux.run"),
          }
          if cushort not in CUDA_RELEASES:
            # Fail with an actionable message instead of a bare KeyError.
            sys.exit(
              "Unsupported cuda_short_version '" + cushort + "'; expected one of: "
              + ", ".join(sorted(CUDA_RELEASES))
            )
          full_version, install_script = CUDA_RELEASES[cushort]
          # Append to the outputs file: "r+" would start writing at offset 0
          # (clobbering anything already there) and fails if the file is missing.
          with open(os.environ['GITHUB_OUTPUT'], "a") as fp:
            fp.write("CUDA_VERSION=" + full_version + "\n")
            if cushort == TORCH_CUDA_DEFAULT:
              fp.write("CUDA_VERSION_SUFFIX=\n")
              fp.write("TORCH_ORG_S3_PATH=s3://pytorch/whl\n")
              fp.write("PUBLISH_PYPI=1\n")
            else:
              fp.write("CUDA_VERSION_SUFFIX=+cu" + cushort + "\n")
              fp.write("TORCH_ORG_S3_PATH=s3://pytorch/whl/" + cushort + "\n")
              fp.write("PUBLISH_PYPI=0\n")
            fp.write("CUDA_INSTALL_SCRIPT=" + install_script + "\n")
      # Print the values computed by the cuda_info step into the job log.
      - run: 'echo "CUDA_VERSION_SUFFIX=${{ steps.cuda_info.outputs.CUDA_VERSION_SUFFIX }}"'
      - run: 'echo "TORCH_ORG_S3_PATH=${{ steps.cuda_info.outputs.TORCH_ORG_S3_PATH }}"'
      - run: 'echo "PUBLISH_PYPI=${{ steps.cuda_info.outputs.PUBLISH_PYPI }}"'

      # Extend TORCH_CUDA_ARCH_LIST with "9.0" for every CUDA version other
      # than 11.6 / 11.7, and export the result to the following steps.
      - name: Add H100 if nvcc 11.08+
        shell: python
        run: |
          import os
          import sys
          print(sys.version)
          cuda_short_version = "${{ inputs.cuda_short_version }}"
          arch_list = os.environ["TORCH_CUDA_ARCH_LIST"]
          if cuda_short_version not in ["116", "117"]:
            arch_list += " 9.0"
          # Append to the env file: "r+" would start writing at offset 0
          # (clobbering anything already there) and fails if the file is missing.
          with open(os.environ['GITHUB_ENV'], "a") as fp:
            fp.write("TORCH_CUDA_ARCH_LIST=" + arch_list + "\n")
      - run: echo "${TORCH_CUDA_ARCH_LIST}"

      # Check out the repository, including all submodules, into the
      # workspace root.
      - name: Recursive checkout
        uses: actions/checkout@v3
        with:
          # Full history, for tags.
          fetch-depth: 0
          path: '.'
          submodules: recursive

      # Create and activate a virtualenv with the selected interpreter, then
      # export the venv's python, the updated PATH and a lower MAX_JOBS to
      # the following steps.
      - name: (Linux) Setup venv for linux
        if: runner.os != 'Windows'
        run: |
          $PY -m venv venv
          source ./venv/bin/activate
          which pip
          {
            echo "PY=$(which python)"
            echo "PATH=$PATH"
            echo "MAX_JOBS=3"
          } >> ${GITHUB_ENV}

      # Compute the wheel version (from the tag, or a dev version otherwise),
      # append the CUDA suffix, and publish it as BUILD_VERSION both to the
      # environment of later steps and as a step output.
      - id: xformers_version
        name: Define version
        env:
          VERSION_SOURCE: ${{ github.ref_type == 'tag' && 'tag' || 'dev' }}
        run: |
          set -Eeuo pipefail
          git config --global --add safe.directory "*"
          pip install packaging ninja
          version=$(python packaging/compute_wheel_version.py --source $VERSION_SOURCE)
          echo $version > version.txt
          build_version="$version${{ steps.cuda_info.outputs.CUDA_VERSION_SUFFIX }}"
          echo "BUILD_VERSION=$build_version" >> ${GITHUB_ENV}
          echo "BUILD_VERSION=$build_version" >> ${GITHUB_OUTPUT}
          which ninja
          cat ${GITHUB_ENV}
      # Log the resulting version; the second message only appears for
      # versions without a '.dev' component.
      - run: 'echo "xformers-${BUILD_VERSION}"'
      - if: ${{ !contains(steps.xformers_version.outputs.BUILD_VERSION, '.dev') }}
        run: 'echo "release version (will upload to PyTorch)"'

      # Drop every requirements line mentioning "torch" and pin the exact
      # PyTorch version requested by the caller instead.
      - name: Setup proper pytorch dependency in "requirements.txt"
        run: |
          requirements=./requirements.txt
          sed -i '/torch/d' $requirements
          echo "torch == ${{ inputs.torch_version }}" >> $requirements
          cat $requirements

      # Windows only: delegate runner preparation to the repository-local
      # composite action, passing the Python and full CUDA versions.
      - name: (Windows) Setup Runner
        if: runner.os == 'Windows'
        uses: ./.github/actions/setup-windows-runner
        with:
          python: ${{ inputs.python }}
          cuda: ${{ steps.cuda_info.outputs.CUDA_VERSION }}

      # Install the build tooling and the project requirements, resolving
      # packages from the PyTorch index for the selected CUDA version as well.
      - name: Install dependencies
        run: >-
          $PY -m pip install wheel setuptools twine
          -r requirements.txt
          --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cuda_short_version }}

      # Linux only: install helper packages, download the CUDA installer
      # selected by the cuda_info step, install the toolkit, then delete the
      # installer. Each command runs only if the previous one succeeded.
      - name: (Linux) install cuda
        if: runner.os == 'Linux'
        run: |
          yum install wget git prename -y \
            && wget -q "${{ steps.cuda_info.outputs.CUDA_INSTALL_SCRIPT }}" -O cuda.run \
            && sh ./cuda.run --silent --toolkit \
            && rm ./cuda.run

      # Build the wheel into dist/. On ubuntu runners the wheel is tagged
      # with the manylinux2014 platform name; elsewhere PLAT_ARG is empty.
      - name: Build wheel
        env:
          PLAT_ARG: ${{ contains(inputs.os, 'ubuntu') && '--plat-name manylinux2014_x86_64' || '' }}
        run: $PY setup.py bdist_wheel -d dist/ -k $PLAT_ARG

      # Report the size of the build output, then upload the wheel(s) as an
      # artifact named after the full build configuration.
      - run: 'du -h dist/*'
      - uses: actions/upload-artifact@v3
        with:
          path: 'dist/*.whl'
          name: "${{ inputs.os }}-py${{ inputs.python }}-torch${{ inputs.torch_version }}+cu${{ inputs.cuda_short_version }}_${{ inputs.artifact_tag }}"
# Note: it might be helpful to have additional steps that test if the built wheels actually work
